{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load CIFAR-10 and pack each split into one 2-D array: every row is a\n",
    "# flattened image followed by its label column(s).\n",
    "import numpy as np\n",
    "\n",
    "# Explicit imports (instead of `from sad_construct import *`) make it clear\n",
    "# where every model and helper used in the cells below comes from.\n",
    "from sad_construct import Fc_cifar10, LeNet5_cifar10, cnn_cifar10, only_test, train_and_test\n",
    "from sad_read_cifar10 import get_cifar10\n",
    "\n",
    "# Shape of one CIFAR-10 image as the models expect it (passed as data_shape below).\n",
    "cifar10_shape = (3, 32, 32)\n",
    "n_features = int(np.prod(cifar10_shape))  # 3*32*32 = 3072 values per flattened image\n",
    "\n",
    "(x_train, y_train), (x_test, y_test) = get_cifar10()\n",
    "x_train = x_train.reshape(-1, n_features)\n",
    "x_test = x_test.reshape(-1, n_features)\n",
    "\n",
    "# Same result as np.append(x, y, axis=1) when the labels are 2-D, but also\n",
    "# accepts 1-D label vectors by turning them into a single column.\n",
    "cifar10_train_data = np.column_stack((x_train, y_train))\n",
    "cifar10_test_data = np.column_stack((x_test, y_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1:\n",
      " Loss:0.621773440 \tAccuracy: 0.6688 [0.745, 0.873, 0.684, 0.529, 0.489, 0.589, 0.709, 0.699, 0.799, 0.572]\n",
      "epoch 2:\n",
      " Loss:0.604697813 \tAccuracy: 0.6963 [0.754, 0.752, 0.572, 0.494, 0.602, 0.525, 0.836, 0.8, 0.831, 0.797]\n",
      "epoch 3:\n",
      " Loss:0.585763745 \tAccuracy: 0.6877 [0.683, 0.749, 0.618, 0.461, 0.596, 0.669, 0.726, 0.783, 0.776, 0.816]\n",
      "epoch 4:\n",
      " Loss:0.570935702 \tAccuracy: 0.6849 [0.82, 0.744, 0.64, 0.401, 0.595, 0.567, 0.775, 0.729, 0.769, 0.809]\n",
      "epoch 5:\n",
      " Loss:0.552346228 \tAccuracy: 0.6946 [0.74, 0.825, 0.65, 0.429, 0.768, 0.525, 0.785, 0.669, 0.782, 0.773]\n",
      "epoch 6:\n",
      " Loss:0.540152195 \tAccuracy: 0.6867 [0.698, 0.842, 0.655, 0.458, 0.571, 0.513, 0.769, 0.74, 0.849, 0.772]\n",
      "epoch 7:\n",
      " Loss:0.525218568 \tAccuracy: 0.6886 [0.753, 0.688, 0.657, 0.557, 0.596, 0.487, 0.777, 0.765, 0.772, 0.834]\n",
      "epoch 8:\n",
      " Loss:0.512682451 \tAccuracy: 0.6869 [0.738, 0.697, 0.603, 0.539, 0.598, 0.61, 0.834, 0.668, 0.771, 0.811]\n",
      "epoch 9:\n",
      " Loss:0.495785392 \tAccuracy: 0.694 [0.762, 0.774, 0.593, 0.421, 0.656, 0.573, 0.781, 0.779, 0.846, 0.755]\n",
      "epoch 10:\n",
      " Loss:0.484911244 \tAccuracy: 0.6785 [0.694, 0.776, 0.681, 0.443, 0.602, 0.477, 0.836, 0.716, 0.813, 0.747]\n",
      "epoch 11:\n",
      " Loss:0.477713646 \tAccuracy: 0.6885 [0.71, 0.751, 0.617, 0.583, 0.583, 0.57, 0.794, 0.693, 0.784, 0.8]\n",
      "epoch 12:\n",
      " Loss:0.464877438 \tAccuracy: 0.6918 [0.742, 0.778, 0.592, 0.386, 0.613, 0.636, 0.853, 0.726, 0.822, 0.77]\n",
      "epoch 13:\n",
      " Loss:0.454125639 \tAccuracy: 0.6888 [0.695, 0.832, 0.601, 0.346, 0.658, 0.637, 0.796, 0.754, 0.81, 0.759]\n",
      "epoch 14:\n",
      " Loss:0.440214514 \tAccuracy: 0.682 [0.759, 0.806, 0.514, 0.328, 0.699, 0.585, 0.769, 0.838, 0.789, 0.733]\n",
      "epoch 15:\n",
      " Loss:0.434776100 \tAccuracy: 0.6869 [0.752, 0.806, 0.524, 0.487, 0.554, 0.571, 0.773, 0.759, 0.84, 0.803]\n",
      "epoch 16:\n",
      " Loss:0.424250075 \tAccuracy: 0.6738 [0.718, 0.722, 0.435, 0.413, 0.761, 0.721, 0.704, 0.65, 0.806, 0.808]\n",
      "epoch 17:\n",
      " Loss:0.417329437 \tAccuracy: 0.6931 [0.741, 0.838, 0.556, 0.49, 0.713, 0.533, 0.825, 0.748, 0.768, 0.719]\n",
      "epoch 18:\n",
      " Loss:0.406446836 \tAccuracy: 0.6877 [0.643, 0.776, 0.595, 0.486, 0.725, 0.545, 0.803, 0.645, 0.856, 0.803]\n",
      "epoch 19:\n",
      " Loss:0.395962938 \tAccuracy: 0.6898 [0.772, 0.714, 0.606, 0.376, 0.644, 0.654, 0.769, 0.724, 0.837, 0.802]\n",
      "epoch 20:\n",
      " Loss:0.391887865 \tAccuracy: 0.6749 [0.734, 0.843, 0.584, 0.365, 0.569, 0.624, 0.671, 0.766, 0.821, 0.772]\n"
     ]
    }
   ],
   "source": [
    "# CNN on CIFAR-10: train for 20 epochs, evaluating on the test split after every epoch.\n",
    "# The trained parameters end up in cnn_cifar10.npy (the only_test cell further down\n",
    "# reproduces this run's final-epoch accuracy from that file).\n",
    "# NOTE(review): unlike the LeNet5/Fc runs, this output has no \"init para\" line and\n",
    "# starts at ~67% accuracy, so it apparently resumed from an existing cnn_cifar10.npy\n",
    "# rather than training from scratch - confirm in sad_construct.train_and_test.\n",
    "# NOTE(review): sad_network.NeuralNetwork.update(learning_rate) appears to call\n",
    "# layer.update() without forwarding learning_rate, and the halving check seems to sit\n",
    "# next to the per-batch update; the learning-rate schedule may not behave as intended -\n",
    "# verify in sad_network.py / sad_construct.py.\n",
    "train_and_test(\n",
    "    mynetwork=cnn_cifar10,\n",
    "    pth=\"cnn_cifar10.npy\",\n",
    "    # data\n",
    "    train_data=cifar10_train_data,\n",
    "    test_data=cifar10_test_data,\n",
    "    data_shape=cifar10_shape,\n",
    "    onehotsize=10,  # number of CIFAR-10 classes\n",
    "    # optimisation schedule\n",
    "    epoch_num=20,\n",
    "    batch_size=16,\n",
    "    half_learning_rate_time=5,  # learning_rate is halved when epoch_id % 5 == 0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init para\n",
      "epoch 1:\n",
      " Loss:1.922219159 \tAccuracy: 0.406 [0.492, 0.523, 0.175, 0.182, 0.536, 0.369, 0.544, 0.446, 0.371, 0.422]\n",
      "epoch 2:\n",
      " Loss:1.592144013 \tAccuracy: 0.462 [0.376, 0.682, 0.401, 0.233, 0.299, 0.471, 0.47, 0.604, 0.618, 0.466]\n",
      "epoch 3:\n",
      " Loss:1.478148090 \tAccuracy: 0.4798 [0.514, 0.77, 0.293, 0.216, 0.357, 0.493, 0.729, 0.561, 0.372, 0.493]\n",
      "epoch 4:\n",
      " Loss:1.409167035 \tAccuracy: 0.4947 [0.603, 0.5, 0.374, 0.272, 0.221, 0.521, 0.745, 0.588, 0.546, 0.577]\n",
      "epoch 5:\n",
      " Loss:1.352145361 \tAccuracy: 0.5003 [0.422, 0.63, 0.327, 0.31, 0.335, 0.307, 0.839, 0.666, 0.547, 0.62]\n",
      "epoch 6:\n",
      " Loss:1.298472515 \tAccuracy: 0.5339 [0.613, 0.755, 0.293, 0.355, 0.417, 0.304, 0.749, 0.634, 0.665, 0.554]\n",
      "epoch 7:\n",
      " Loss:1.245916749 \tAccuracy: 0.5524 [0.496, 0.683, 0.413, 0.279, 0.535, 0.469, 0.595, 0.707, 0.654, 0.693]\n",
      "epoch 8:\n",
      " Loss:1.195568580 \tAccuracy: 0.5634 [0.578, 0.78, 0.439, 0.38, 0.45, 0.483, 0.583, 0.618, 0.726, 0.597]\n",
      "epoch 9:\n",
      " Loss:1.152158322 \tAccuracy: 0.5733 [0.606, 0.685, 0.467, 0.262, 0.407, 0.567, 0.795, 0.552, 0.758, 0.634]\n",
      "epoch 10:\n",
      " Loss:1.113603944 \tAccuracy: 0.5846 [0.634, 0.748, 0.531, 0.354, 0.531, 0.291, 0.735, 0.57, 0.754, 0.698]\n",
      "epoch 11:\n",
      " Loss:1.080373576 \tAccuracy: 0.5945 [0.664, 0.706, 0.522, 0.358, 0.379, 0.481, 0.751, 0.661, 0.742, 0.681]\n",
      "epoch 12:\n",
      " Loss:1.046075358 \tAccuracy: 0.5927 [0.65, 0.638, 0.299, 0.365, 0.695, 0.534, 0.64, 0.62, 0.678, 0.808]\n",
      "epoch 13:\n",
      " Loss:1.018113104 \tAccuracy: 0.5858 [0.776, 0.814, 0.42, 0.398, 0.438, 0.418, 0.581, 0.713, 0.7, 0.6]\n",
      "epoch 14:\n",
      " Loss:0.991829370 \tAccuracy: 0.6059 [0.667, 0.74, 0.553, 0.27, 0.496, 0.49, 0.687, 0.722, 0.804, 0.63]\n",
      "epoch 15:\n",
      " Loss:0.965960775 \tAccuracy: 0.6033 [0.675, 0.561, 0.545, 0.422, 0.474, 0.411, 0.648, 0.71, 0.821, 0.766]\n",
      "epoch 16:\n",
      " Loss:0.939595968 \tAccuracy: 0.6003 [0.708, 0.72, 0.369, 0.526, 0.541, 0.432, 0.596, 0.648, 0.78, 0.683]\n",
      "epoch 17:\n",
      " Loss:0.915861357 \tAccuracy: 0.6105 [0.61, 0.74, 0.474, 0.403, 0.506, 0.501, 0.708, 0.639, 0.781, 0.743]\n",
      "epoch 18:\n",
      " Loss:0.895364062 \tAccuracy: 0.6169 [0.632, 0.731, 0.456, 0.393, 0.626, 0.469, 0.695, 0.649, 0.797, 0.721]\n",
      "epoch 19:\n",
      " Loss:0.873217901 \tAccuracy: 0.6138 [0.738, 0.746, 0.444, 0.459, 0.525, 0.474, 0.689, 0.728, 0.706, 0.629]\n",
      "epoch 20:\n",
      " Loss:0.853968613 \tAccuracy: 0.603 [0.64, 0.756, 0.439, 0.37, 0.523, 0.397, 0.852, 0.644, 0.748, 0.661]\n"
     ]
    }
   ],
   "source": [
    "# LeNet-5 on CIFAR-10: train for 20 epochs, evaluating on the test split after every\n",
    "# epoch. The checkpoint path is LeNet5_cifar10.npy; the recorded run started from\n",
    "# freshly initialised parameters (see the \"init para\" line in its output).\n",
    "train_and_test(\n",
    "    mynetwork=LeNet5_cifar10,\n",
    "    pth=\"LeNet5_cifar10.npy\",\n",
    "    # data\n",
    "    train_data=cifar10_train_data,\n",
    "    test_data=cifar10_test_data,\n",
    "    data_shape=cifar10_shape,\n",
    "    onehotsize=10,  # number of CIFAR-10 classes\n",
    "    # optimisation schedule\n",
    "    epoch_num=20,\n",
    "    batch_size=16,\n",
    "    half_learning_rate_time=5,  # learning_rate is halved when epoch_id % 5 == 0\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.6749 [0.734 0.843 0.584 0.365 0.569 0.624 0.671 0.766 0.821 0.772]\n",
      "Cross Accuracy:\n",
      "[[0.734 0.032 0.034 0.01  0.013 0.014 0.005 0.006 0.104 0.048]\n",
      " [0.015 0.843 0.    0.003 0.002 0.003 0.003 0.004 0.031 0.096]\n",
      " [0.086 0.013 0.584 0.044 0.068 0.072 0.038 0.047 0.037 0.011]\n",
      " [0.032 0.032 0.07  0.365 0.064 0.243 0.052 0.064 0.043 0.035]\n",
      " [0.033 0.011 0.105 0.046 0.569 0.055 0.035 0.113 0.026 0.007]\n",
      " [0.024 0.026 0.058 0.087 0.033 0.624 0.018 0.098 0.012 0.02 ]\n",
      " [0.011 0.043 0.066 0.047 0.047 0.051 0.671 0.017 0.017 0.03 ]\n",
      " [0.02  0.011 0.029 0.023 0.051 0.058 0.006 0.766 0.013 0.023]\n",
      " [0.068 0.051 0.006 0.008 0.005 0.004 0.001 0.004 0.821 0.032]\n",
      " [0.023 0.128 0.011 0.006 0.003 0.007 0.002 0.016 0.032 0.772]]\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the CNN parameters stored in cnn_cifar10.npy on the test split only (no training).\n",
    "# Prints overall and per-class accuracy plus a 10x10 \"Cross Accuracy\" matrix; in the\n",
    "# recorded output each row sums to 1 and its diagonal equals the per-class accuracy,\n",
    "# i.e. row i shows how test images of class i were classified.\n",
    "only_test(\n",
    "    mynetwork=cnn_cifar10,\n",
    "    pth=\"cnn_cifar10.npy\",\n",
    "    test_data=cifar10_test_data,\n",
    "    data_shape=cifar10_shape,\n",
    "    onehotsize=10,  # number of CIFAR-10 classes\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "init para\n",
      "epoch 1:\n",
      " Loss:1.839917913 \tAccuracy: 0.4077 [0.547, 0.471, 0.388, 0.123, 0.294, 0.406, 0.546, 0.301, 0.537, 0.464]\n",
      "epoch 2:\n",
      " Loss:1.643068802 \tAccuracy: 0.4387 [0.515, 0.609, 0.495, 0.115, 0.206, 0.371, 0.483, 0.467, 0.678, 0.448]\n",
      "epoch 3:\n",
      " Loss:1.556201690 \tAccuracy: 0.4539 [0.515, 0.531, 0.401, 0.334, 0.236, 0.327, 0.653, 0.502, 0.695, 0.345]\n",
      "epoch 4:\n",
      " Loss:1.494295085 \tAccuracy: 0.4769 [0.616, 0.655, 0.162, 0.239, 0.288, 0.481, 0.542, 0.726, 0.613, 0.447]\n",
      "epoch 5:\n",
      " Loss:1.442280935 \tAccuracy: 0.4825 [0.503, 0.458, 0.309, 0.254, 0.377, 0.529, 0.667, 0.541, 0.63, 0.557]\n",
      "epoch 6:\n",
      " loss:1.795122288 \tprocess: 67.96%\t   "
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[2], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m train_and_test(\n\u001b[0;32m      2\u001b[0m     mynetwork\u001b[38;5;241m=\u001b[39mFc_cifar10,\n\u001b[0;32m      3\u001b[0m     train_data\u001b[38;5;241m=\u001b[39mcifar10_train_data,\n\u001b[0;32m      4\u001b[0m     test_data\u001b[38;5;241m=\u001b[39mcifar10_test_data,\n\u001b[0;32m      5\u001b[0m     data_shape\u001b[38;5;241m=\u001b[39mcifar10_shape,\n\u001b[0;32m      6\u001b[0m     onehotsize\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m,\n\u001b[0;32m      7\u001b[0m     batch_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m16\u001b[39m,\n\u001b[0;32m      8\u001b[0m     epoch_num\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m20\u001b[39m,\n\u001b[0;32m      9\u001b[0m     half_learning_rate_time\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m,\n\u001b[0;32m     10\u001b[0m     pth\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfc_cifar10.npy\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m     11\u001b[0m )\n",
      "File \u001b[1;32md:\\Users\\Fox\\桌面\\sad_machine_learning\\sad_construct.py:22\u001b[0m, in \u001b[0;36mtrain_and_test\u001b[1;34m(mynetwork, train_data, test_data, data_shape, onehotsize, batch_size, epoch_num, learning_rate, half_learning_rate_time, pth)\u001b[0m\n\u001b[0;32m     20\u001b[0m loss\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39ml\n\u001b[0;32m     21\u001b[0m mynetwork\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m---> 22\u001b[0m mynetwork\u001b[38;5;241m.\u001b[39mupdate(learning_rate)\n\u001b[0;32m     23\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m epoch_id\u001b[38;5;241m%\u001b[39mhalf_learning_rate_time\u001b[38;5;241m==\u001b[39m\u001b[38;5;241m0\u001b[39m: learning_rate\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m\n\u001b[0;32m     24\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\r\u001b[39;00m\u001b[38;5;124m loss:\u001b[39m\u001b[38;5;132;01m%.9f\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m%\u001b[39ml,\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124mprocess:\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mint\u001b[39m(idx_batch\u001b[38;5;241m/\u001b[39m(train_data\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m/\u001b[39mbatch_size)\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m10000\u001b[39m)\u001b[38;5;241m/\u001b[39m\u001b[38;5;241m100\u001b[39m,end\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m%\u001b[39m\u001b[38;5;130;01m\\t\u001b[39;00m\u001b[38;5;124m   \u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "File \u001b[1;32md:\\Users\\Fox\\桌面\\sad_machine_learning\\sad_network.py:32\u001b[0m, in \u001b[0;36mNeuralNetwork.update\u001b[1;34m(self, learning_rate)\u001b[0m\n\u001b[0;32m     30\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mupdate\u001b[39m(\u001b[38;5;28mself\u001b[39m,learning_rate\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.01\u001b[39m):\n\u001b[0;32m     31\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlayers:\n\u001b[1;32m---> 32\u001b[0m         layer\u001b[38;5;241m.\u001b[39mupdate()\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Fc_cifar10 (presumably a fully-connected baseline) on CIFAR-10: train for 20 epochs,\n",
    "# evaluating on the test split after every epoch; the checkpoint path is fc_cifar10.npy.\n",
    "# NOTE(review): the saved output of this cell is a manual KeyboardInterrupt during\n",
    "# epoch 6, recorded with execution count 2 (out of order) - re-run it to completion\n",
    "# before comparing against the CNN / LeNet-5 results.\n",
    "train_and_test(\n",
    "    mynetwork=Fc_cifar10,\n",
    "    pth=\"fc_cifar10.npy\",\n",
    "    # data\n",
    "    train_data=cifar10_train_data,\n",
    "    test_data=cifar10_test_data,\n",
    "    data_shape=cifar10_shape,\n",
    "    onehotsize=10,  # number of CIFAR-10 classes\n",
    "    # optimisation schedule\n",
    "    epoch_num=20,\n",
    "    batch_size=16,\n",
    "    half_learning_rate_time=5,  # learning_rate is halved when epoch_id % 5 == 0\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "face",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
